
[MLIR][ROCDL] Remove patterns for ops supported as intrinsics in the AMDGPU backend #102971

Merged: 5 commits into llvm:main from jleyonberg/math-rocdl-update on Sep 4, 2024

Conversation

@jsjodin (Contributor) commented Aug 12, 2024

This patch removes the patterns for a few operations, which allows the mathToLLVM conversion to convert those operations into LLVM intrinsics instead, since they are supported directly by the AMDGPU backend.
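
For illustration only (not part of the patch text), a minimal sketch of the effect on math.exp for f32, assuming the usual MathToROCDL + MathToLLVM pipeline:

    // Input:
    %0 = math.exp %arg : f32

    // Before this patch (MathToROCDL pattern, OCML library call):
    %0 = llvm.call @__ocml_exp_f32(%arg) : (f32) -> f32

    // After this patch (MathToLLVM, LLVM intrinsic handled directly by the backend):
    %0 = llvm.intr.exp(%arg) : (f32) -> f32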

@llvmbot (Member) commented Aug 12, 2024

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-gpu

Author: Jan Leyonberg (jsjodin)

Changes

This patch removes the patterns for a few operations, which allows the mathToLLVM conversion to convert those operations into LLVM intrinsics instead, since they are supported directly by the AMDGPU backend.


Full diff: https://github.com/llvm/llvm-project/pull/102971.diff

3 Files Affected:

  • (modified) mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp (+4-8)
  • (modified) mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir (+17-85)
  • (modified) mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir (-60)
diff --git a/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp b/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp
index 7de6971ba2ee72..fd4eab0e10d67e 100644
--- a/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp
+++ b/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp
@@ -48,18 +48,20 @@ static void populateOpPatterns(LLVMTypeConverter &converter,
 void mlir::populateMathToROCDLConversionPatterns(LLVMTypeConverter &converter,
                                                  RewritePatternSet &patterns) {
   // Handled by mathToLLVM: math::AbsIOp
+  // Handled by mathToLLVM: math::AbsFOp
   // Handled by mathToLLVM: math::CopySignOp
   // Handled by mathToLLVM: math::CountLeadingZerosOp
   // Handled by mathToLLVM: math::CountTrailingZerosOp
  // Handled by mathToLLVM: math::CtPopOp
+  // Handled by mathToLLVM: math::ExpOp
   // Handled by mathToLLVM: math::FmaOp
+  // Handled by mathToLLVM: math::LogOp
   // FIXME: math::IPowIOp
   // FIXME: math::FPowIOp
   // Handled by mathToLLVM: math::RoundEvenOp
   // Handled by mathToLLVM: math::RoundOp
+  // Handled by mathToLLVM: math::SqrtOp
   // Handled by mathToLLVM: math::TruncOp
-  populateOpPatterns<math::AbsFOp>(converter, patterns, "__ocml_fabs_f32",
-                                   "__ocml_fabs_f64");
   populateOpPatterns<math::AcosOp>(converter, patterns, "__ocml_acos_f32",
                                    "__ocml_acos_f64");
   populateOpPatterns<math::AcoshOp>(converter, patterns, "__ocml_acosh_f32",
@@ -84,16 +86,12 @@ void mlir::populateMathToROCDLConversionPatterns(LLVMTypeConverter &converter,
                                    "__ocml_cosh_f64");
   populateOpPatterns<math::SinhOp>(converter, patterns, "__ocml_sinh_f32",
                                    "__ocml_sinh_f64");
-  populateOpPatterns<math::ExpOp>(converter, patterns, "__ocml_exp_f32",
-                                  "__ocml_exp_f64");
   populateOpPatterns<math::Exp2Op>(converter, patterns, "__ocml_exp2_f32",
                                    "__ocml_exp2_f64");
   populateOpPatterns<math::ExpM1Op>(converter, patterns, "__ocml_expm1_f32",
                                     "__ocml_expm1_f64");
   populateOpPatterns<math::FloorOp>(converter, patterns, "__ocml_floor_f32",
                                     "__ocml_floor_f64");
-  populateOpPatterns<math::LogOp>(converter, patterns, "__ocml_log_f32",
-                                  "__ocml_log_f64");
   populateOpPatterns<math::Log10Op>(converter, patterns, "__ocml_log10_f32",
                                     "__ocml_log10_f64");
   populateOpPatterns<math::Log1pOp>(converter, patterns, "__ocml_log1p_f32",
@@ -106,8 +104,6 @@ void mlir::populateMathToROCDLConversionPatterns(LLVMTypeConverter &converter,
                                     "__ocml_rsqrt_f64");
   populateOpPatterns<math::SinOp>(converter, patterns, "__ocml_sin_f32",
                                   "__ocml_sin_f64");
-  populateOpPatterns<math::SqrtOp>(converter, patterns, "__ocml_sqrt_f32",
-                                   "__ocml_sqrt_f64");
   populateOpPatterns<math::TanhOp>(converter, patterns, "__ocml_tanh_f32",
                                    "__ocml_tanh_f64");
   populateOpPatterns<math::TanOp>(converter, patterns, "__ocml_tan_f32",
diff --git a/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir b/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir
index bf49a42a115775..4f1f26e8794d9e 100644
--- a/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir
+++ b/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir
@@ -131,21 +131,6 @@ gpu.module @test_module {
 
 // -----
 
-gpu.module @test_module {
-  // CHECK: llvm.func @__ocml_fabs_f32(f32) -> f32
-  // CHECK: llvm.func @__ocml_fabs_f64(f64) -> f64
-  // CHECK-LABEL: func @gpu_fabs
-  func.func @gpu_fabs(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
-    %result32 = math.absf %arg_f32 : f32
-    // CHECK: llvm.call @__ocml_fabs_f32(%{{.*}}) : (f32) -> f32
-    %result64 = math.absf %arg_f64 : f64
-    // CHECK: llvm.call @__ocml_fabs_f64(%{{.*}}) : (f64) -> f64
-    func.return %result32, %result64 : f32, f64
-  }
-}
-
-// -----
-
 gpu.module @test_module {
   // CHECK: llvm.func @__ocml_cbrt_f32(f32) -> f32
   // CHECK: llvm.func @__ocml_cbrt_f64(f64) -> f64
@@ -206,23 +191,6 @@ gpu.module @test_module {
 
 // -----
 
-gpu.module @test_module {
-  // CHECK: llvm.func @__ocml_exp_f32(f32) -> f32
-  // CHECK: llvm.func @__ocml_exp_f64(f64) -> f64
-  // CHECK-LABEL: func @gpu_exp
-  func.func @gpu_exp(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
-    %exp_f32 = math.exp %arg_f32 : f32
-    // CHECK: llvm.call @__ocml_exp_f32(%{{.*}}) : (f32) -> f32
-    %result32 = math.exp %exp_f32 : f32
-    // CHECK: llvm.call @__ocml_exp_f32(%{{.*}}) : (f32) -> f32
-    %result64 = math.exp %arg_f64 : f64
-    // CHECK: llvm.call @__ocml_exp_f64(%{{.*}}) : (f64) -> f64
-    func.return %result32, %result64 : f32, f64
-  }
-}
-
-// -----
-
 gpu.module @test_module {
   // CHECK: llvm.func @__ocml_exp2_f32(f32) -> f32
   // CHECK: llvm.func @__ocml_exp2_f64(f64) -> f64
@@ -239,21 +207,20 @@ gpu.module @test_module {
 }
 
 // -----
-
 // Test that we handled properly operation with SymbolTable other than module op
 gpu.module @test_module {
   "test.symbol_scope"() ({
     // CHECK: test.symbol_scope
-    // CHECK: llvm.func @__ocml_exp_f32(f32) -> f32
-    // CHECK: llvm.func @__ocml_exp_f64(f64) -> f64
-    // CHECK-LABEL: func @gpu_exp
-    func.func @gpu_exp(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
-      %exp_f32 = math.exp %arg_f32 : f32
-      // CHECK: llvm.call @__ocml_exp_f32(%{{.*}}) : (f32) -> f32
-      %result32 = math.exp %exp_f32 : f32
-      // CHECK: llvm.call @__ocml_exp_f32(%{{.*}}) : (f32) -> f32
-      %result64 = math.exp %arg_f64 : f64
-      // CHECK: llvm.call @__ocml_exp_f64(%{{.*}}) : (f64) -> f64
+    // CHECK: llvm.func @__ocml_sin_f32(f32) -> f32
+    // CHECK: llvm.func @__ocml_sin_f64(f64) -> f64
+    // CHECK-LABEL: func @gpu_sin
+    func.func @gpu_sin(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+      %sin_f32 = math.sin %arg_f32 : f32
+      // CHECK: llvm.call @__ocml_sin_f32(%{{.*}}) : (f32) -> f32
+      %result32 = math.sin %sin_f32 : f32
+      // CHECK: llvm.call @__ocml_sin_f32(%{{.*}}) : (f32) -> f32
+      %result64 = math.sin %arg_f64 : f64
+      // CHECK: llvm.call @__ocml_sin_f64(%{{.*}}) : (f64) -> f64
       func.return %result32, %result64 : f32, f64
     }
     "test.finish" () : () -> ()
@@ -279,21 +246,6 @@ gpu.module @test_module {
 
 // -----
 
-gpu.module @test_module {
-  // CHECK: llvm.func @__ocml_log_f32(f32) -> f32
-  // CHECK: llvm.func @__ocml_log_f64(f64) -> f64
-  // CHECK-LABEL: func @gpu_log
-  func.func @gpu_log(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
-    %result32 = math.log %arg_f32 : f32
-    // CHECK: llvm.call @__ocml_log_f32(%{{.*}}) : (f32) -> f32
-    %result64 = math.log %arg_f64 : f64
-    // CHECK: llvm.call @__ocml_log_f64(%{{.*}}) : (f64) -> f64
-    func.return %result32, %result64 : f32, f64
-  }
-}
-
-// -----
-
 gpu.module @test_module {
   // CHECK: llvm.func @__ocml_log1p_f32(f32) -> f32
   // CHECK: llvm.func @__ocml_log1p_f64(f64) -> f64
@@ -359,26 +311,6 @@ gpu.module @test_module {
 
 // -----
 
-gpu.module @test_module {
-  // CHECK: llvm.func @__ocml_sqrt_f32(f32) -> f32
-  // CHECK: llvm.func @__ocml_sqrt_f64(f64) -> f64
-  // CHECK-LABEL: func @gpu_sqrt
-  func.func @gpu_sqrt(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64)
-      -> (f16, f32, f64) {
-    %result16 = math.sqrt %arg_f16 : f16
-    // CHECK: llvm.fpext %{{.*}} : f16 to f32
-    // CHECK-NEXT: llvm.call @__ocml_sqrt_f32(%{{.*}}) : (f32) -> f32
-    // CHECK-NEXT: llvm.fptrunc %{{.*}} : f32 to f16
-    %result32 = math.sqrt %arg_f32 : f32
-    // CHECK: llvm.call @__ocml_sqrt_f32(%{{.*}}) : (f32) -> f32
-    %result64 = math.sqrt %arg_f64 : f64
-    // CHECK: llvm.call @__ocml_sqrt_f64(%{{.*}}) : (f64) -> f64
-    func.return %result16, %result32, %result64 : f16, f32, f64
-  }
-}
-
-// -----
-
 gpu.module @test_module {
   // CHECK: llvm.func @__ocml_tan_f32(f32) -> f32
   // CHECK: llvm.func @__ocml_tan_f64(f64) -> f64
@@ -472,15 +404,15 @@ gpu.module @test_module {
 gpu.module @test_module {
   // CHECK-LABEL: func @gpu_unroll
   func.func @gpu_unroll(%arg0 : vector<4xf32>) -> vector<4xf32> {
-    %result = math.exp %arg0 : vector<4xf32>
+    %result = math.sin %arg0 : vector<4xf32>
     // CHECK: %[[V0:.+]] = llvm.mlir.undef : vector<4xf32>
-    // CHECK: %[[CL:.+]] = llvm.call @__ocml_exp_f32(%{{.*}}) : (f32) -> f32
+    // CHECK: %[[CL:.+]] = llvm.call @__ocml_sin_f32(%{{.*}}) : (f32) -> f32
     // CHECK: %[[V1:.+]] = llvm.insertelement %[[CL]], %[[V0]]
-    // CHECK: %[[CL:.+]] = llvm.call @__ocml_exp_f32(%{{.*}}) : (f32) -> f32
+    // CHECK: %[[CL:.+]] = llvm.call @__ocml_sin_f32(%{{.*}}) : (f32) -> f32
     // CHECK: %[[V2:.+]] = llvm.insertelement %[[CL]], %[[V1]]
-    // CHECK: %[[CL:.+]] = llvm.call @__ocml_exp_f32(%{{.*}}) : (f32) -> f32
+    // CHECK: %[[CL:.+]] = llvm.call @__ocml_sin_f32(%{{.*}}) : (f32) -> f32
     // CHECK: %[[V3:.+]] = llvm.insertelement %[[CL]], %[[V2]]
-    // CHECK: %[[CL:.+]] = llvm.call @__ocml_exp_f32(%{{.*}}) : (f32) -> f32
+    // CHECK: %[[CL:.+]] = llvm.call @__ocml_sin_f32(%{{.*}}) : (f32) -> f32
     // CHECK: %[[V4:.+]] = llvm.insertelement %[[CL]], %[[V3]]
     // CHECK: return %[[V4]]
     func.return %result : vector<4xf32>
@@ -526,9 +458,9 @@ gpu.module @test_module {
 
 gpu.module @module {
 // CHECK-LABEL: @spirv_exp
-// CHECK: llvm.call @__ocml_exp_f32
+// CHECK: llvm.call @__ocml_sin_f32
   spirv.func @spirv_exp(%arg0: vector<4xf32>) -> vector<4xf32> "None" {
-    %0 = math.exp %arg0 : vector<4xf32>
+    %0 = math.sin %arg0 : vector<4xf32>
     spirv.ReturnValue %0 : vector<4xf32>
   }
 }
diff --git a/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir b/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir
index a406ec45a7f109..9a05a94f9f1ac7 100644
--- a/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir
+++ b/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir
@@ -15,21 +15,6 @@ module @test_module {
 
 // -----
 
-module @test_module {
-  // CHECK: llvm.func @__ocml_fabs_f32(f32) -> f32
-  // CHECK: llvm.func @__ocml_fabs_f64(f64) -> f64
-  // CHECK-LABEL: func @math_absf
-  func.func @math_absf(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
-    %result32 = math.absf %arg_f32 : f32
-    // CHECK: llvm.call @__ocml_fabs_f32(%{{.*}}) : (f32) -> f32
-    %result64 = math.absf %arg_f64 : f64
-    // CHECK: llvm.call @__ocml_fabs_f64(%{{.*}}) : (f64) -> f64
-    func.return %result32, %result64 : f32, f64
-  }
-}
-
-// -----
-
 module @test_module {
   // CHECK: llvm.func @__ocml_acos_f32(f32) -> f32
   // CHECK: llvm.func @__ocml_acos_f64(f64) -> f64
@@ -210,21 +195,6 @@ module @test_module {
 
 // -----
 
-module @test_module {
-  // CHECK: llvm.func @__ocml_exp_f32(f32) -> f32
-  // CHECK: llvm.func @__ocml_exp_f64(f64) -> f64
-  // CHECK-LABEL: func @math_exp
-  func.func @math_exp(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
-    %result32 = math.exp %arg_f32 : f32
-    // CHECK: llvm.call @__ocml_exp_f32(%{{.*}}) : (f32) -> f32
-    %result64 = math.exp %arg_f64 : f64
-    // CHECK: llvm.call @__ocml_exp_f64(%{{.*}}) : (f64) -> f64
-    func.return %result32, %result64 : f32, f64
-  }
-}
-
-// -----
-
 module @test_module {
   // CHECK: llvm.func @__ocml_exp2_f32(f32) -> f32
   // CHECK: llvm.func @__ocml_exp2_f64(f64) -> f64
@@ -270,21 +240,6 @@ module @test_module {
 
 // -----
 
-module @test_module {
-  // CHECK: llvm.func @__ocml_log_f32(f32) -> f32
-  // CHECK: llvm.func @__ocml_log_f64(f64) -> f64
-  // CHECK-LABEL: func @math_log
-  func.func @math_log(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
-    %result32 = math.log %arg_f32 : f32
-    // CHECK: llvm.call @__ocml_log_f32(%{{.*}}) : (f32) -> f32
-    %result64 = math.log %arg_f64 : f64
-    // CHECK: llvm.call @__ocml_log_f64(%{{.*}}) : (f64) -> f64
-    func.return %result32, %result64 : f32, f64
-  }
-}
-
-// -----
-
 module @test_module {
   // CHECK: llvm.func @__ocml_log10_f32(f32) -> f32
   // CHECK: llvm.func @__ocml_log10_f64(f64) -> f64
@@ -360,21 +315,6 @@ module @test_module {
 
 // -----
 
-module @test_module {
-  // CHECK: llvm.func @__ocml_sqrt_f32(f32) -> f32
-  // CHECK: llvm.func @__ocml_sqrt_f64(f64) -> f64
-  // CHECK-LABEL: func @math_sqrt
-  func.func @math_sqrt(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
-    %result32 = math.sqrt %arg_f32 : f32
-    // CHECK: llvm.call @__ocml_sqrt_f32(%{{.*}}) : (f32) -> f32
-    %result64 = math.sqrt %arg_f64 : f64
-    // CHECK: llvm.call @__ocml_sqrt_f64(%{{.*}}) : (f64) -> f64
-    func.return %result32, %result64 : f32, f64
-  }
-}
-
-// -----
-
 module @test_module {
   // CHECK: llvm.func @__ocml_tanh_f32(f32) -> f32
   // CHECK: llvm.func @__ocml_tanh_f64(f64) -> f64

@krzysz00 (Contributor) left a comment:
I think this is fine, but I want to double-check that nothing'll go wrong with double-precision exp and log

Inline comment on mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp, at:

    populateOpPatterns<math::Exp2Op>(converter, patterns, "__ocml_exp2_f32",
                                     "__ocml_exp2_f64");
    populateOpPatterns<math::ExpM1Op>(converter, patterns, "__ocml_expm1_f32",
                                      "__ocml_expm1_f64");
    populateOpPatterns<math::FloorOp>(converter, patterns, "__ocml_floor_f32",
                                      "__ocml_floor_f64");
    populateOpPatterns<math::LogOp>(converter, patterns, "__ocml_log_f32",

@krzysz00 (Contributor): Same note re f64 log

Inline comment on mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp (@@ -84,16 +86,12 @@), at:

                                     "__ocml_cosh_f64");
    populateOpPatterns<math::SinhOp>(converter, patterns, "__ocml_sinh_f32",
                                     "__ocml_sinh_f64");
    populateOpPatterns<math::ExpOp>(converter, patterns, "__ocml_exp_f32",

@krzysz00 (Contributor):
Having disassembled OCML, the double-precision exp isn't actually a direct wrapper around the relevant intrinsic, but I figure that's probably fine

@arsenm (Contributor) commented Aug 13, 2024:

Correct. We only directly handle the f32 (and I think f16) versions. The f64 versions of the hard operations do not work. We do directly handle llvm.sqrt.f64 as an exception.

Also, none of the trig functions are directly handled (correctly). We do codegen the f32 versions but probably shouldn't

@jsjodin (Contributor, Author):

So what do we do in the case where f32 is handled but f64 is not: should we still call OCML for both, or modify the lowering to handle just one of them?

@krzysz00 (Contributor):

Modify the lowering to handle just one. The operation + type should be treated like different operations, so emit the working f32 intrinsics and the calls for the nonworking f64

@jsjodin (Contributor, Author):

Okay, I put the lowering for f64 back for log and exp and added the tests back for them as well.
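
As a sketch of the resulting split for exp (illustrative; note the diff shown above is the earlier revision, which removed both widths):

    // f32: lowered by mathToLLVM to the working intrinsic:
    %r32 = llvm.intr.exp(%arg_f32) : (f32) -> f32
    // f64: kept as a MathToROCDL pattern, lowered to the OCML call:
    %r64 = llvm.call @__ocml_exp_f64(%arg_f64) : (f64) -> f64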

…AMDGPU backend

This patch removes patterns for a few operations, which allows the mathToLLVM
conversion to convert the operations into LLVM intrinsics instead, since they
are supported directly by the AMDGPU backend.

@jsjodin force-pushed the jleyonberg/math-rocdl-update branch from 8fe4b0e to 58f0fc6 on September 1, 2024 14:38
@jsjodin requested a review from krzysz00 on September 1, 2024 15:09
@krzysz00 (Contributor) left a comment:

Approved

(If it's possible, could you get the half-precision functions hooked up in a later PR?)

@jsjodin (Contributor, Author) commented Sep 4, 2024:

> Approved
>
> (If it's possible, could you get the half-precision functions hooked up in a later PR?)

Sure, that should be pretty easy to do.
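
A sketch of what that follow-up might look like (hypothetical; __ocml_sin_f16 is assumed here as an example half-precision OCML entry point and is not referenced in this PR):

    // Current f16 path for OCML-lowered ops: extend, call the f32 function, truncate.
    %0 = llvm.fpext %arg : f16 to f32
    %1 = llvm.call @__ocml_sin_f32(%0) : (f32) -> f32
    %2 = llvm.fptrunc %1 : f32 to f16

    // Possible follow-up: call a half-precision entry point directly.
    %r = llvm.call @__ocml_sin_f16(%arg) : (f16) -> f16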

@jsjodin merged commit 3ebd797 into llvm:main on Sep 4, 2024
8 checks passed
nirvedhmeshram added a commit that referenced this pull request Sep 11, 2024
LLVM::FAbsOp and LLVM::SqrtOp are legal after #102971
nirvedhmeshram added a commit to iree-org/llvm-project that referenced this pull request Sep 11, 2024
LLVM::FAbsOp and LLVM::SqrtOp are legal after llvm#102971
VitaNuo pushed a commit to VitaNuo/llvm-project that referenced this pull request Sep 12, 2024
nirvedhmeshram added a commit that referenced this pull request Sep 12, 2024
…108302)

Similar to #108266. After #102971 it is legal to generate `LLVM::ExpOp` and
`LLVM::LogOp` if the type is a float16 or float32.
@jsjodin deleted the jleyonberg/math-rocdl-update branch on October 1, 2024 14:53
@akuegel (Member) commented Nov 25, 2024:

This patch now breaks the lowering of bf16 ExpOp and LogOp. For bf16, we upcast to f32, and now there is no lowering for f32 anymore.
Moreover, I wonder whether it actually makes a difference in the final lowering: if there are intrinsics, shouldn't the other pattern also eventually lower to them?

@akuegel (Member) commented Nov 25, 2024:

@krzysz00 XLA is triggering this bug, where bf16 ExpOp and LogOp can no longer be lowered: openxla/xla#19700

@krzysz00 (Contributor):

@akuegel You'll want to run --arith-emulate-unsupported-floats to expand out the bf16 operations, I think.

Alternatively, I'll take a followup patch to lower BF16 math.* to the equivalent f32 calls (LLVM intrinsics or OCML ones as relevant) with extf and truncf
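
A minimal sketch of what that pass does here, assuming bf16 is configured as a source type and f32 as the target type:

    // Input:
    %r = math.exp %x : bf16

    // After --arith-emulate-unsupported-floats, the op runs on f32,
    // where MathToLLVM can lower it to llvm.intr.exp:
    %0 = arith.extf %x : bf16 to f32
    %1 = math.exp %0 : f32
    %2 = arith.truncf %1 : f32 to bf16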

@krzysz00 (Contributor):

That is to say, there is no bf16 exp or log and they need to be rewritten to their f32 counterparts

@akuegel (Member) commented Nov 25, 2024:

> That is to say, there is no bf16 exp or log and they need to be rewritten to their f32 counterparts

Yes, and that is also done in this pass, but due to this PR it says there is no lowering pattern for f32. Reverting this patch would fix it.

@akuegel (Member) commented Nov 26, 2024:

Let me expand a bit. OpToFuncCallLowering (which is the shared logic with the NVVM conversion) upcasts bf16 to f32 before checking whether there is a device function available:

https://github.com/llvm/llvm-project/blob/main/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h#L71

Previously, this worked fine, as when upcasting ExpOp (or LogOp) from bf16 to f32, there was a corresponding f32 function specified. Now, this doesn't work anymore, as there is no corresponding f32 function specified. Users of this pass now need to upcast bf16 ops earlier. Is this really what we want?

Let me repeat my earlier question: does this patch actually make a difference in the final lowering of math::ExpOp for f32? If there are intrinsics, shouldn't the other pattern also eventually lower to them? And if not, wouldn't the right fix be to change the OCML functions to make use of the intrinsics? @krzysz00 @jsjodin @arsenm, maybe one of you can answer this?
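
Concretely, an illustrative sketch of the path that broke (not XLA's exact IR):

    // OpToFuncCallLowering upcasts bf16 before looking up a device function:
    %0 = llvm.fpext %x : bf16 to f32
    %1 = llvm.call @__ocml_exp_f32(%0) : (f32) -> f32  // pattern existed before this PR
    %2 = llvm.fptrunc %1 : f32 to bf16
    // After this PR, no f32 exp function is registered in MathToROCDL, so the
    // upcasted op finds no pattern and the conversion fails.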

@krzysz00 (Contributor) commented Dec 2, 2024:

This patch does change the lowering to use the compiler intrinsics. The OCML functions in question were, for f32, redundant wrappers for the intrinsics that the compiler team wants to remove.

For bf16 exp and log, I'd almost think that an LLVM patch is in order

@akuegel (Member) commented Dec 4, 2024:

For XLA, there is now openxla/xla#19913, which should fix the issue by upcasting log and exp ops with bf16 type early.
I will leave it to people from AMD to figure out whether that is the permanent solution or not; I just wanted to point out the (possibly unintended) side effect of this patch.

@krzysz00 (Contributor) commented Dec 5, 2024:

@arsenm and other folks on the compiler side - would there be any reason not to expand exp and log on bfloats during SelectionDAG/GISel?

@arsenm (Contributor) commented Dec 5, 2024:

> @arsenm and other folks on the compiler side - would there be any reason not to expand exp and log on bfloats during SelectionDAG/GISel?

No. In principle all the math intrinsics should be legalized for all types

@krzysz00 (Contributor) commented Dec 5, 2024:

Ok, so maybe I'm misreading the above comments and we need an MLIR-side patch to make sure that the bf16 math.exp and math.log lower to their LLVM intrinsics
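
That is, an MLIR-side sketch of the desired end state (assuming, per the comment above, that the backend legalizes the intrinsic for all types):

    // Desired: mathToLLVM emits the intrinsic even for bf16, and
    // SelectionDAG/GISel expands it to f32 in the backend.
    %r = llvm.intr.exp(%x) : (bf16) -> bf16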
